import torch
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torchvision
from PIL import Image
from xml.dom.minidom import parse
import utils
import transforms as T
from engine import train_one_epoch, evaluate
import xml.etree.cElementTree as ET
import collections
import pandas as pd
from torchvision.transforms import functional
import time
# from label_list import label_list

# label_list = ['background', 'bottle']


def predict_once(model, img):
    """Run one inference pass of a detection model on a single image.

    Parameters
    ----------
    model : torch.nn.Module
        A torchvision-style detection model whose forward takes a list of
        CHW float tensors and returns a list of dicts with ``'boxes'``,
        ``'labels'`` and ``'scores'`` tensors, e.g.::

            [{'boxes': tensor([[1492.7, 238.5, 1765.5, 315.0], ...]),
              'labels': tensor([1, 1]),
              'scores': tensor([1.0000, 1.0000])}]

    img : PIL.Image or ndarray
        RGB image; converted to a [0, 1] float tensor by torchvision's
        ``to_tensor``.

    Returns
    -------
    list[dict]
        The raw prediction list (one dict per input image).
    """
    # Switch to inference mode (disables dropout, uses running BN stats).
    model.eval()

    img_tensor = functional.to_tensor(img)

    # Send the input to whichever device the model actually lives on
    # instead of hard-coding 'cuda' -- the original crashed on CPU-only
    # machines even though main() selects the device correctly.
    device = next(model.parameters()).device

    with torch.no_grad():
        prediction = model([img_tensor.to(device)])

    return prediction



def main(args):
    """Live webcam object detection loop.

    Loads class names from ``label_list.txt`` (one name per line, line
    index == label id) and a trained detection model from ``model.pkl``,
    then reads frames from camera 0, draws boxes and class names for
    detections scoring above the threshold, and displays both the raw
    and the annotated frame until 'q' is pressed.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (currently unused inside the loop).
    """
    # Parse the label-list file, dropping trailing newlines and blank lines.
    with open("label_list.txt") as file:
        label_list = [line.rstrip() for line in file if line.strip()]

    # Pick the device first so the checkpoint can be mapped onto it:
    # torch.load without map_location fails on CPU-only machines when the
    # model was saved from a GPU.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = torch.load(r'model.pkl', map_location=device)
    model.to(device)
    # Switch to inference mode.
    model.eval()

    score_threshold = 0.8  # only draw detections at least this confident

    # Read frames from the default camera.
    cap = cv2.VideoCapture(0)
    cv2.namedWindow('real_img', cv2.WINDOW_NORMAL)

    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            # Camera unplugged / stream ended -- don't process a None frame.
            break
        src_img = frame

        # Show the raw frame.
        cv2.imshow('real_img', src_img)

        # OpenCV delivers BGR; the model expects RGB.
        img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
        prediction = predict_once(model=model, img=img)

        boxes = prediction[0]['boxes']
        scores = prediction[0]['scores']
        labels = [label_list[label] for label in prediction[0]['labels']]
        print(labels)

        for idx in range(boxes.shape[0]):
            if scores[idx] <= score_threshold:
                continue
            x1, y1, x2, y2 = (int(v) for v in boxes[idx][:4])
            cv2.rectangle(src_img, (x1, y1), (x2, y2), (255, 0, 0), thickness=2)
            cv2.putText(src_img, labels[idx], (x1, y1),
                        fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
                        color=(0, 255, 0))

        # Show the annotated frame; press 'q' to quit.
        cv2.imshow('result', src_img)
        key = cv2.waitKey(1) & 0xFF  # renamed from 'input' (shadowed a builtin)
        if key == ord('q'):
            break

    cap.release()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    import argparse

    # Build the command-line interface. Only the two options below are
    # actually parsed; the many training-related flags from the reference
    # script are not part of this detector.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--epochs', default=26, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--print-freq', default=20, type=int,
                        help='print frequency')

    # Parse the command line and hand the result to the entry point.
    main(parser.parse_args())